// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. See the AUTHORS file for names of contributors.

#include "db/skiplist.h"

#include <atomic>
#include <set>

#include "leveldb/env.h"
#include "port/port.h"
#include "port/thread_annotations.h"
#include "util/arena.h"
#include "util/hash.h"
#include "util/random.h"
#include "util/testharness.h"

namespace leveldb {

    // Keys stored in the skip lists under test are plain 64-bit unsigned integers.
    using Key = uint64_t;

    // Three-way ordering functor over Key: the result is negative when a sorts
    // before b, positive when a sorts after b, and zero when they are equal.
    struct Comparator {
        int operator()(const Key &a, const Key &b) const {
            if (a == b) {
                return 0;
            }
            return (a < b) ? -1 : +1;
        }
    };

    // Empty fixture type; it exists only to name the suite used by the
    // TEST(SkipTest, ...) cases in this file.
    class SkipTest {};

    // An empty skip list contains nothing, and an iterator over it is never
    // valid regardless of how it is positioned.
    TEST(SkipTest, Empty) {
        Arena arena;
        Comparator cmp;
        SkipList<Key, Comparator> empty_list(cmp, &arena);
        ASSERT_TRUE(!empty_list.Contains(10));

        // Freshly constructed iterator.
        SkipList<Key, Comparator>::Iterator it(&empty_list);
        ASSERT_TRUE(!it.Valid());

        // Positioning at the front.
        it.SeekToFirst();
        ASSERT_TRUE(!it.Valid());

        // Positioning at an arbitrary target.
        it.Seek(100);
        ASSERT_TRUE(!it.Valid());

        // Positioning at the back.
        it.SeekToLast();
        ASSERT_TRUE(!it.Valid());
    }

    // Inserts pseudo-random keys into a skip list and a std::set side by side,
    // then checks that membership, seeking, forward iteration and backward
    // iteration of the skip list all agree with the std::set model.
    TEST(SkipTest, InsertAndLookup) {
        typedef SkipList<Key, Comparator> List;

        const int kInsertAttempts = 2000;
        const int kKeyRange = 5000;
        Random rnd(1000);
        std::set<Key> model;
        Arena arena;
        Comparator cmp;
        List list(cmp, &arena);

        // Populate both containers; only keys that are new to the model are
        // handed to the skip list.
        for (int attempt = 0; attempt < kInsertAttempts; ++attempt) {
            Key candidate = rnd.Next() % kKeyRange;
            const bool is_new = model.insert(candidate).second;
            if (is_new) {
                list.Insert(candidate);
            }
        }

        // Membership must match the model for every key in the range.
        for (int probe = 0; probe < kKeyRange; ++probe) {
            if (!list.Contains(probe)) {
                ASSERT_EQ(model.count(probe), 0);
            } else {
                ASSERT_EQ(model.count(probe), 1);
            }
        }

        // Basic positioning: Seek(0), SeekToFirst() and SeekToLast().
        {
            List::Iterator it(&list);
            ASSERT_TRUE(!it.Valid());

            it.Seek(0);
            ASSERT_TRUE(it.Valid());
            ASSERT_EQ(*(model.begin()), it.key());

            it.SeekToFirst();
            ASSERT_TRUE(it.Valid());
            ASSERT_EQ(*(model.begin()), it.key());

            it.SeekToLast();
            ASSERT_TRUE(it.Valid());
            ASSERT_EQ(*(model.rbegin()), it.key());
        }

        // Forward iteration: from every possible start key, the next (up to)
        // three entries must match the model's lower_bound sequence.
        for (int start = 0; start < kKeyRange; ++start) {
            List::Iterator it(&list);
            it.Seek(start);

            std::set<Key>::iterator expected = model.lower_bound(start);
            int checked = 0;
            while (checked < 3) {
                if (expected == model.end()) {
                    ASSERT_TRUE(!it.Valid());
                    break;
                }
                ASSERT_TRUE(it.Valid());
                ASSERT_EQ(*expected, it.key());
                ++expected;
                it.Next();
                ++checked;
            }
        }

        // Backward iteration: walk from the last entry to before the first,
        // matching the model's reverse order at every step.
        {
            List::Iterator it(&list);
            it.SeekToLast();

            std::set<Key>::reverse_iterator expected = model.rbegin();
            while (expected != model.rend()) {
                ASSERT_TRUE(it.Valid());
                ASSERT_EQ(*expected, it.key());
                it.Prev();
                ++expected;
            }
            ASSERT_TRUE(!it.Valid());
        }
    }

// We want to make sure that with a single writer and multiple
// concurrent readers (with no synchronization other than when a
// reader's iterator is created), the reader always observes all the
// data that was present in the skip list when the iterator was
// constructed.  Because insertions are happening concurrently, we may
// also observe new values that were inserted since the iterator was
// constructed, but we should never miss any values that were present
// at iterator construction time.
//
// We generate multi-part keys:
//     <key,gen,hash>
// where:
//     key is in range [0..K-1]
//     gen is a generation number for key
//     hash is hash(key,gen)
//
// The insertion code picks a random key, sets gen to be 1 + the last
// generation number inserted for that key, and sets hash to Hash(key,gen).
//
// At the beginning of a read, we snapshot the last inserted
// generation number for each key.  We then iterate, including random
// calls to Next() and Seek().  For every key we encounter, we
// check that it is either expected given the initial snapshot or has
// been concurrently added since the iterator started.
    class ConcurrentTest {
    private:
        // Number of distinct "key" parts; valid key parts are 0..K-1.
        static const uint32_t K = 4;

        // Field extractors for a packed Key.  Layout (see MakeKey):
        //   bits 40 and up : key part
        //   bits 8..39     : generation (32 bits)
        //   bits 0..7      : low byte of HashNumbers(key part, generation)
        static uint64_t key(Key key) { return (key >> 40); }

        static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }

        static uint64_t hash(Key key) { return key & 0xff; }

        // Hashes the raw bytes of the pair {k, g} with Hash(); the final
        // argument 0 is presumably the hash seed.
        static uint64_t HashNumbers(uint64_t k, uint64_t g) {
            uint64_t data[2] = {k, g};
            return Hash(reinterpret_cast<char *>(data), sizeof(data), 0);
        }

        // Packs key part k and generation g, plus the low 8 bits of their
        // hash, into a single Key using the layout described above.
        static Key MakeKey(uint64_t k, uint64_t g) {
            assert(sizeof(Key) == sizeof(uint64_t));
            assert(k <= K);  // We sometimes pass K to seek to the end of the skiplist
            assert(g <= 0xffffffffu);
            return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
        }

        // Returns true if the hash byte stored in k matches the hash
        // recomputed from k's key part and generation.
        static bool IsValidKey(Key k) {
            return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
        }

        // Picks a seek target: of the ten outcomes of rnd->Next() % 10, one
        // selects the very beginning, one selects the end (key part K), and
        // the rest select generation 0 of a random key part.
        static Key RandomTarget(Random *rnd) {
            switch (rnd->Next() % 10) {
                case 0:
                    // Seek to beginning
                    return MakeKey(0, 0);
                case 1:
                    // Seek to end
                    return MakeKey(K, 0);
                default:
                    // Seek to middle
                    return MakeKey(rnd->Next() % K, 0);
            }
        }

        // Per-key generation
        struct State {
            // generation[k] is the last generation recorded for key part k.
            std::atomic<int> generation[K];

            // Stores v for key part k with release ordering.
            void Set(int k, int v) {
                generation[k].store(v, std::memory_order_release);
            }

            // Loads the generation for key part k with acquire ordering.
            int Get(int k) {
                return generation[k].load(std::memory_order_acquire);
            }

            // Starts every key part at generation 0.
            State() {
                for (int k = 0; k < K; k++) {
                    Set(k, 0);
                }
            }
        };

        // Current state of the test
        State current_;

        // Arena handed to list_ at construction.
        Arena arena_;

        // list_ is not protected by any lock in this class.  We just use a
        // single writer thread to modify it.
        SkipList<Key, Comparator> list_;

    public:
        ConcurrentTest() : list_(Comparator(), &arena_) {}

        // Inserts the next generation of a randomly chosen key part.
        // REQUIRES: External synchronization
        void WriteStep(Random *rnd) {
            const uint32_t k = rnd->Next() % K;
            const intptr_t g = current_.Get(k) + 1;
            const Key key = MakeKey(k, g);
            // Insert first, publish the generation second, so current_ never
            // names a generation that has not been inserted yet.
            list_.Insert(key);
            current_.Set(k, g);
        }

        // Performs one randomized read pass over list_, asserting that every
        // key skipped over by the iterator was absent from the snapshot of
        // current_ taken at the start of the pass.
        void ReadStep(Random *rnd) {
            // Remember the initial committed state of the skiplist.
            State initial_state;
            for (int k = 0; k < K; k++) {
                initial_state.Set(k, current_.Get(k));
            }

            // pos tracks the smallest key the iterator is still allowed to
            // return; everything below it has already been accounted for.
            Key pos = RandomTarget(rnd);
            SkipList<Key, Comparator>::Iterator iter(&list_);
            iter.Seek(pos);
            while (true) {
                Key current;
                if (!iter.Valid()) {
                    // Treat end-of-list as a sentinel just past the last
                    // valid key part.
                    current = MakeKey(K, 0);
                } else {
                    current = iter.key();
                    ASSERT_TRUE(IsValidKey(current)) << current;
                }
                ASSERT_LE(pos, current) << "should not go backwards";

                // Verify that everything in [pos,current) was not present in
                // initial_state.
                while (pos < current) {
                    ASSERT_LT(key(pos), K) << pos;

                    // Note that generation 0 is never inserted, so it is ok if
                    // <*,0,*> is missing.
                    ASSERT_TRUE((gen(pos) == 0) ||
                                (gen(pos) > static_cast<Key>(initial_state.Get(key(pos))))
                    ) << "key: " << key(pos)
                      << "; gen: " << gen(pos)
                      << "; initgen: "
                      << initial_state.Get(key(pos));

                    // Advance to next key in the valid key space
                    if (key(pos) < key(current)) {
                        pos = MakeKey(key(pos) + 1, 0);
                    } else {
                        pos = MakeKey(key(pos), gen(pos) + 1);
                    }
                }

                if (!iter.Valid()) {
                    break;
                }

                // Randomly either step forward or re-seek to a new target.
                if (rnd->Next() % 2) {
                    iter.Next();
                    // Expect the next generation of the same key part next.
                    pos = MakeKey(key(pos), gen(pos) + 1);
                } else {
                    // Only seek forward so that pos never moves backwards.
                    Key new_target = RandomTarget(rnd);
                    if (new_target > pos) {
                        pos = new_target;
                        iter.Seek(new_target);
                    }
                }
            }
        }
    };

    const uint32_t ConcurrentTest::K;

// Simple test that does single-threaded testing of the ConcurrentTest
// scaffolding.
    TEST(SkipTest, ConcurrentWithoutThreads) {
        const int kRounds = 10000;
        ConcurrentTest harness;
        Random rnd(test::RandomSeed());
        // Each round runs one read pass followed by one insertion, all on the
        // calling thread.
        int round = 0;
        while (round < kRounds) {
            harness.ReadStep(&rnd);
            harness.WriteStep(&rnd);
            ++round;
        }
    }

    // State shared between the writing test thread and one reader: the
    // ConcurrentTest under test, the reader's random seed, a quit flag, and a
    // mutex/condition-variable pair used to hand off the reader's lifecycle
    // state.
    class TestState {
    public:
        ConcurrentTest t_;
        int seed_;
        std::atomic<bool> quit_flag_;

        // Lifecycle of the reader as tracked by state_.
        enum ReaderState { STARTING, RUNNING, DONE };

        // Begins in STARTING with the quit flag cleared; s becomes seed_.
        explicit TestState(int s)
                : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {}

        // Blocks the caller until the reader state equals s.
        void Wait(ReaderState s) LOCKS_EXCLUDED(mu_) {
            mu_.Lock();
            while (!(state_ == s)) {
                state_cv_.Wait();
            }
            mu_.Unlock();
        }

        // Records s as the new reader state and signals the condition
        // variable, all while holding mu_.
        void Change(ReaderState s) LOCKS_EXCLUDED(mu_) {
            mu_.Lock();
            state_ = s;
            state_cv_.Signal();
            mu_.Unlock();
        }

    private:
        port::Mutex mu_;
        ReaderState state_ GUARDED_BY(mu_);
        port::CondVar state_cv_ GUARDED_BY(mu_);
    };

    // Reader entry point.  arg must point to a TestState.  Announces RUNNING,
    // repeatedly performs ReadStep() until the quit flag is observed, then
    // announces DONE.
    static void ConcurrentReader(void *arg) {
        // static_cast is the appropriate named cast for void* -> object pointer.
        TestState *state = static_cast<TestState *>(arg);
        Random rnd(state->seed_);
        state->Change(TestState::RUNNING);
        while (!state->quit_flag_.load(std::memory_order_acquire)) {
            state->t_.ReadStep(&rnd);
        }
        state->Change(TestState::DONE);
    }

    static void RunConcurrent(int run) {
        const int seed = test::RandomSeed() + (run * 100);
        Random rnd(seed);
        const int N = 1000;
        const int kSize = 1000;
        for (int i = 0; i < N; i++) {
            if ((i % 100) == 0) {
                fprintf(stderr, "Run %d of %d\n", i, N);
            }
            TestState state(seed + 1);
            Env::Default()->Schedule(ConcurrentReader, &state);
            state.Wait(TestState::RUNNING);
            for (int i = 0; i < kSize; i++) {
                state.t_.WriteStep(&rnd);
            }
            state.quit_flag_.store(true, std::memory_order_release);
            state.Wait(TestState::DONE);
        }
    }

    // Five instances of the concurrent test; each passes a different run
    // index to RunConcurrent(), which folds it into the random seed.
    TEST(SkipTest, Concurrent1) {
        RunConcurrent(1);
    }

    TEST(SkipTest, Concurrent2) {
        RunConcurrent(2);
    }

    TEST(SkipTest, Concurrent3) {
        RunConcurrent(3);
    }

    TEST(SkipTest, Concurrent4) {
        RunConcurrent(4);
    }

    TEST(SkipTest, Concurrent5) {
        RunConcurrent(5);
    }

}  // namespace leveldb

// Program entry point: delegates to leveldb::test::RunAllTests() and returns
// its result as the process exit status.  Command-line arguments are ignored.
int main(int argc, char **argv) {
    (void) argc;
    (void) argv;
    const int exit_status = leveldb::test::RunAllTests();
    return exit_status;
}
